{ "cells": [ { "cell_type": "markdown", "source": [ "# Custom Statistic\n", "\n", "In this notebook, we'll implement a new statistic as a subclass of LinearFractionalStatistic. All fairrets work with any LinearFractionalStatistic, so they will work with out new statistic as well.\n", "\n", "We take inspiration from the paper \"[Generalizing Group Fairness in Machine Learning via Utilities](https://jair.org/index.php/jair/article/view/14238/26985)\" by Blandin and Kash. For the well-known German Credit Dataset, they propose the following cost for predictions $\\hat{Y}$ and ground truth labels $Y$:\n", "$$C = \\begin{cases}\n", " 0 & \\text{ if } \\hat{Y} = Y \\\\\n", " 1 & \\text{ if } \\hat{Y} = 0 \\wedge Y = 1 \\\\\n", " 5 & \\text{ if } \\hat{Y} = 1 \\wedge Y = 0\n", "\\end{cases}$$\n", "\n", "The costs are motivated by the fact that a loan applicant that receives a loan $(\\hat{Y} = 1)$ but will not repay it $(Y = 0)$ will have to default, which is considered far worse than when an applicant is rejected $(\\hat{Y} = 0)$ that would have repaid $(Y = 1)$ the loan.\n", "\n", "The statistic in this case is the average cost $C$ incurred over all individuals in a sensitive group. Hence, the statistic is canonically formalized as\n", "$$\\gamma(k, f) = \\frac{\\mathbb{E}[SC]}{\\mathbb{E}[S]} = \\frac{\\mathbb{E}[S(1 Y(1 - f(X)) + 5 (1 - Y)f(X))]}{\\mathbb{E}[S]} = \\frac{\\mathbb{E}[S(Y + (5 - 6Y)f(X))]}{\\mathbb{E}[S]}$$\n", "where we filled in $\\hat{Y}$ with the probabilistic $f(X)$.\n", "\n", "The canonical form allows us to identify how the statistic has a linear-fractional form with respect to $f$. Ignoring $S$ for a moment, the intercept of the numerator is $Y$ and the slope is $(5 - 6Y)$. The denominator is not dependent on $f$.\n", "\n", "The statistic is then implemented as:" ], "metadata": { "collapsed": false }, "id": "ff3654068006473d" }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "import torch\n", "from fairret.statistic import LinearFractionalStatistic\n", "\n", "class CustomCost(LinearFractionalStatistic):\n", " def num_intercept(self, label: torch.Tensor) -> torch.Tensor:\n", " return label\n", "\n", " def num_slope(self, label: torch.Tensor) -> torch.Tensor:\n", " return 5 - 6 * label\n", "\n", " def denom_intercept(self, label: torch.Tensor) -> torch.Tensor:\n", " return 1\n", "\n", " def denom_slope(self, label: torch.Tensor) -> torch.Tensor:\n", " return 0." ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-02T07:45:32.306864700Z", "start_time": "2024-04-02T07:45:29.975983100Z" } }, "id": "initial_id" }, { "cell_type": "markdown", "source": [ "Let's quickly try it out..." 
], "metadata": { "collapsed": false }, "id": "d160ad2b65016c65" }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "import torch\n", "torch.manual_seed(0)\n", "\n", "feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])\n", "sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])\n", "label = torch.tensor([[0.], [1.], [0.], [1.]])\n", "\n", "from fairret.loss import NormLoss\n", "\n", "statistic = CustomCost()\n", "norm_loss = NormLoss(statistic)\n", "\n", "h_layer_dim = 16\n", "lr = 1e-3\n", "batch_size = 1024\n", "\n", "def build_model():\n", " model = torch.nn.Sequential(\n", " torch.nn.Linear(feat.shape[1], h_layer_dim),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(h_layer_dim, 1)\n", " )\n", " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", " return model, optimizer\n", "\n", "from torch.utils.data import TensorDataset, DataLoader\n", "dataset = TensorDataset(feat, sens, label)\n", "dataloader = DataLoader(dataset, batch_size=batch_size)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-02T07:45:32.487093100Z", "start_time": "2024-04-02T07:45:32.306864700Z" } }, "id": "e7e46c4842e9d9e5" }, { "cell_type": "markdown", "source": [ "Without fairret..." ], "metadata": { "collapsed": false }, "id": "8ba44232a13a288a" }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.7091795206069946\n", "Epoch: 1, loss: 0.7061765193939209\n", "Epoch: 2, loss: 0.7033581733703613\n", "Epoch: 3, loss: 0.7007156610488892\n", "Epoch: 4, loss: 0.6982340812683105\n", "Epoch: 5, loss: 0.6959078907966614\n", "Epoch: 6, loss: 0.6937355995178223\n", "Epoch: 7, loss: 0.6917158365249634\n", "Epoch: 8, loss: 0.6898466944694519\n", "Epoch: 9, loss: 0.6881252527236938\n", "Epoch: 10, loss: 0.6865478754043579\n", "Epoch: 11, loss: 0.6851094961166382\n", "Epoch: 12, loss: 0.6838041543960571\n", "Epoch: 13, loss: 0.6826250553131104\n", "Epoch: 14, loss: 0.6815641522407532\n", "Epoch: 15, loss: 0.6806124448776245\n", "Epoch: 16, loss: 0.6797604560852051\n", "Epoch: 17, loss: 0.6789975762367249\n", "Epoch: 18, loss: 0.6783132553100586\n", "Epoch: 19, loss: 0.6776963472366333\n", "Epoch: 20, loss: 0.6771360039710999\n", "Epoch: 21, loss: 0.6766215562820435\n", "Epoch: 22, loss: 0.6761429309844971\n", "Epoch: 23, loss: 0.6756909489631653\n", "Epoch: 24, loss: 0.6752569675445557\n", "Epoch: 25, loss: 0.6748337745666504\n", "Epoch: 26, loss: 0.674415111541748\n", "Epoch: 27, loss: 0.673996090888977\n", "Epoch: 28, loss: 0.6735726594924927\n", "Epoch: 29, loss: 0.6731564998626709\n", "Epoch: 30, loss: 0.6727579236030579\n", "Epoch: 31, loss: 0.672345757484436\n", "Epoch: 32, loss: 0.6719199419021606\n", "Epoch: 33, loss: 0.6714813709259033\n", "Epoch: 34, loss: 0.6710319519042969\n", "Epoch: 35, loss: 0.6705741882324219\n", "Epoch: 36, loss: 0.6701083779335022\n", "Epoch: 37, loss: 0.669636607170105\n", "Epoch: 38, loss: 0.6691610217094421\n", "Epoch: 39, loss: 0.6686834096908569\n", "Epoch: 40, loss: 0.6682056188583374\n", "Epoch: 41, loss: 0.6677289009094238\n", "Epoch: 42, loss: 0.667254626750946\n", "Epoch: 43, loss: 0.6667835712432861\n", "Epoch: 44, loss: 0.6663164496421814\n", "Epoch: 45, loss: 0.6658533811569214\n", "Epoch: 46, loss: 0.6653945446014404\n", "Epoch: 47, loss: 0.6649397015571594\n", "Epoch: 48, loss: 0.6644884347915649\n", "Epoch: 49, loss: 0.6640403270721436\n", "Epoch: 50, loss: 0.6635947227478027\n", "Epoch: 51, loss: 0.6631510257720947\n", "Epoch: 
52, loss: 0.6628453135490417\n", "Epoch: 53, loss: 0.6625917553901672\n", "Epoch: 54, loss: 0.6623181104660034\n", "Epoch: 55, loss: 0.6620256900787354\n", "Epoch: 56, loss: 0.6617173552513123\n", "Epoch: 57, loss: 0.6614043116569519\n", "Epoch: 58, loss: 0.6610796451568604\n", "Epoch: 59, loss: 0.6607442498207092\n", "Epoch: 60, loss: 0.6603990793228149\n", "Epoch: 61, loss: 0.6600450277328491\n", "Epoch: 62, loss: 0.6596829295158386\n", "Epoch: 63, loss: 0.6593135595321655\n", "Epoch: 64, loss: 0.6589376330375671\n", "Epoch: 65, loss: 0.6585558652877808\n", "Epoch: 66, loss: 0.6581688523292542\n", "Epoch: 67, loss: 0.6577771306037903\n", "Epoch: 68, loss: 0.6574320793151855\n", "Epoch: 69, loss: 0.6571431756019592\n", "Epoch: 70, loss: 0.6568371653556824\n", "Epoch: 71, loss: 0.6565203666687012\n", "Epoch: 72, loss: 0.6561905145645142\n", "Epoch: 73, loss: 0.6558488607406616\n", "Epoch: 74, loss: 0.65549635887146\n", "Epoch: 75, loss: 0.6551340818405151\n", "Epoch: 76, loss: 0.6547629237174988\n", "Epoch: 77, loss: 0.6544535160064697\n", "Epoch: 78, loss: 0.6541627645492554\n", "Epoch: 79, loss: 0.6538523435592651\n", "Epoch: 80, loss: 0.6535260677337646\n", "Epoch: 81, loss: 0.6531944274902344\n", "Epoch: 82, loss: 0.6528521776199341\n", "Epoch: 83, loss: 0.6525000333786011\n", "Epoch: 84, loss: 0.652138888835907\n", "Epoch: 85, loss: 0.6518597602844238\n", "Epoch: 86, loss: 0.6515651345252991\n", "Epoch: 87, loss: 0.6512539982795715\n", "Epoch: 88, loss: 0.6509299874305725\n", "Epoch: 89, loss: 0.650594174861908\n", "Epoch: 90, loss: 0.6502466797828674\n", "Epoch: 91, loss: 0.6498894691467285\n", "Epoch: 92, loss: 0.6495950818061829\n", "Epoch: 93, loss: 0.6493034362792969\n", "Epoch: 94, loss: 0.6489962339401245\n", "Epoch: 95, loss: 0.6486777067184448\n", "Epoch: 96, loss: 0.6483432650566101\n", "Epoch: 97, loss: 0.6479994058609009\n", "Epoch: 98, loss: 0.6476455330848694\n", "Epoch: 99, loss: 0.6473514437675476\n", "The CustomCost for group 0 is 1.4111454486846924\n", "The CustomCost for group 1 is 1.690650224685669\n", "The absolute difference is 0.27950477600097656\n" ] } ], "source": [ "import numpy as np\n", "\n", "nb_epochs = 100\n", "model, optimizer = build_model()\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        loss.backward()\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")\n", "\n", "# Evaluate the trained model and compare the statistic across the two groups.\n", "pred = torch.sigmoid(model(feat))\n", "stat_per_group = statistic(pred, sens, label)\n", "absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])\n", "\n", "print(f\"The {statistic.__class__.__name__} for group 0 is {stat_per_group[0]}\")\n", "print(f\"The {statistic.__class__.__name__} for group 1 is {stat_per_group[1]}\")\n", "print(f\"The absolute difference is {absolute_diff}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-02T07:45:32.865400900Z", "start_time": "2024-04-02T07:45:32.493318800Z" } }, "id": "bd73c6dbd11d4106" }, { "cell_type": "markdown", "source": [ "With fairret..."
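, "\n", "The only change relative to the training loop above is one extra term in the batch loss, where `fairness_strength` weighs the fairret term against the binary cross-entropy:\n", "\n", "```python\n", "loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)\n", "```"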
], "metadata": { "collapsed": false }, "id": "dea6c331b680a941" }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.7234874963760376\n", "Epoch: 1, loss: 0.7193881869316101\n", "Epoch: 2, loss: 0.7153821587562561\n", "Epoch: 3, loss: 0.7114719748497009\n", "Epoch: 4, loss: 0.7076587677001953\n", "Epoch: 5, loss: 0.703943133354187\n", "Epoch: 6, loss: 0.7003254294395447\n", "Epoch: 7, loss: 0.6968045234680176\n", "Epoch: 8, loss: 0.6933779120445251\n", "Epoch: 9, loss: 0.7009283900260925\n", "Epoch: 10, loss: 0.70442134141922\n", "Epoch: 11, loss: 0.7051329016685486\n", "Epoch: 12, loss: 0.7039238810539246\n", "Epoch: 13, loss: 0.7013260126113892\n", "Epoch: 14, loss: 0.6976962685585022\n", "Epoch: 15, loss: 0.693289577960968\n", "Epoch: 16, loss: 0.6954131722450256\n", "Epoch: 17, loss: 0.6971543431282043\n", "Epoch: 18, loss: 0.6984840035438538\n", "Epoch: 19, loss: 0.6994330883026123\n", "Epoch: 20, loss: 0.7000323534011841\n", "Epoch: 21, loss: 0.700312077999115\n", "Epoch: 22, loss: 0.7003009915351868\n", "Epoch: 23, loss: 0.7000272870063782\n", "Epoch: 24, loss: 0.699517011642456\n", "Epoch: 25, loss: 0.6987953782081604\n", "Epoch: 26, loss: 0.6978861689567566\n", "Epoch: 27, loss: 0.6968110203742981\n", "Epoch: 28, loss: 0.6955899596214294\n", "Epoch: 29, loss: 0.6942419409751892\n", "Epoch: 30, loss: 0.6940961480140686\n", "Epoch: 31, loss: 0.6957410573959351\n", "Epoch: 32, loss: 0.6959663033485413\n", "Epoch: 33, loss: 0.6949725151062012\n", "Epoch: 34, loss: 0.6932772397994995\n", "Epoch: 35, loss: 0.6938610076904297\n", "Epoch: 36, loss: 0.6941598057746887\n", "Epoch: 37, loss: 0.6941966414451599\n", "Epoch: 38, loss: 0.6939942836761475\n", "Epoch: 39, loss: 0.6935751438140869\n", "Epoch: 40, loss: 0.6936072707176208\n", "Epoch: 41, loss: 0.6936568021774292\n", "Epoch: 42, loss: 0.6934249401092529\n", "Epoch: 43, loss: 0.6936383843421936\n", "Epoch: 44, loss: 0.6935983896255493\n", "Epoch: 45, loss: 0.6933280825614929\n", "Epoch: 46, loss: 0.6938256025314331\n", "Epoch: 47, loss: 0.693622350692749\n", "Epoch: 48, loss: 0.69352787733078\n", "Epoch: 49, loss: 0.693831205368042\n", "Epoch: 50, loss: 0.6938722133636475\n", "Epoch: 51, loss: 0.693673849105835\n", "Epoch: 52, loss: 0.6932585835456848\n", "Epoch: 53, loss: 0.694254457950592\n", "Epoch: 54, loss: 0.6943134665489197\n", "Epoch: 55, loss: 0.6932371258735657\n", "Epoch: 56, loss: 0.6940481066703796\n", "Epoch: 57, loss: 0.6946710348129272\n", "Epoch: 58, loss: 0.695000171661377\n", "Epoch: 59, loss: 0.6950598955154419\n", "Epoch: 60, loss: 0.6948744654655457\n", "Epoch: 61, loss: 0.6944674253463745\n", "Epoch: 62, loss: 0.6938610672950745\n", "Epoch: 63, loss: 0.693301260471344\n", "Epoch: 64, loss: 0.6937004923820496\n", "Epoch: 65, loss: 0.6932454705238342\n", "Epoch: 66, loss: 0.6933281421661377\n", "Epoch: 67, loss: 0.6931682229042053\n", "Epoch: 68, loss: 0.6939371824264526\n", "Epoch: 69, loss: 0.6935523152351379\n", "Epoch: 70, loss: 0.6936286687850952\n", "Epoch: 71, loss: 0.693997859954834\n", "Epoch: 72, loss: 0.6940965056419373\n", "Epoch: 73, loss: 0.6939487457275391\n", "Epoch: 74, loss: 0.6935782432556152\n", "Epoch: 75, loss: 0.6934525370597839\n", "Epoch: 76, loss: 0.6934391856193542\n", "Epoch: 77, loss: 0.6935307383537292\n", "Epoch: 78, loss: 0.6937640905380249\n", "Epoch: 79, loss: 0.6937397718429565\n", "Epoch: 80, loss: 0.6934816837310791\n", "Epoch: 81, loss: 0.6934418678283691\n", "Epoch: 82, loss: 
0.6932307481765747\n", "Epoch: 83, loss: 0.6937059760093689\n", "Epoch: 84, loss: 0.6940118670463562\n", "Epoch: 85, loss: 0.6940526366233826\n", "Epoch: 86, loss: 0.693852961063385\n", "Epoch: 87, loss: 0.6934364438056946\n", "Epoch: 88, loss: 0.6938560009002686\n", "Epoch: 89, loss: 0.6939221024513245\n", "Epoch: 90, loss: 0.6932809352874756\n", "Epoch: 91, loss: 0.6934866309165955\n", "Epoch: 92, loss: 0.6934378147125244\n", "Epoch: 93, loss: 0.6931588053703308\n", "Epoch: 94, loss: 0.6941953897476196\n", "Epoch: 95, loss: 0.6940174698829651\n", "Epoch: 96, loss: 0.6933366656303406\n", "Epoch: 97, loss: 0.6936314105987549\n", "Epoch: 98, loss: 0.6936633586883545\n", "Epoch: 99, loss: 0.6934562921524048\n", "The CustomCost for group 0 is 1.4996007680892944\n", "The CustomCost for group 1 is 1.4993300437927246\n", "The absolute difference is 0.0002707242965698242\n" ] } ], "source": [ "import numpy as np\n", "\n", "nb_epochs = 100\n", "fairness_strength = 1\n", "model, optimizer = build_model()\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)\n", "        loss.backward()\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")\n", "\n", "# Evaluate the trained model and compare the statistic across the two groups.\n", "pred = torch.sigmoid(model(feat))\n", "stat_per_group = statistic(pred, sens, label)\n", "absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])\n", "\n", "print(f\"The {statistic.__class__.__name__} for group 0 is {stat_per_group[0]}\")\n", "print(f\"The {statistic.__class__.__name__} for group 1 is {stat_per_group[1]}\")\n", "print(f\"The absolute difference is {absolute_diff}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-02T07:45:33.079663500Z", "start_time": "2024-04-02T07:45:32.870584900Z" } }, "id": "4bc786354387b2af" }
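, { "cell_type": "markdown", "source": [ "The gap between the two groups' average costs shrinks from roughly $0.28$ without the fairret to about $0.0003$ with it.\n", "\n", "Finally, since all fairrets accept any LinearFractionalStatistic, we can swap in a different fairret without touching the statistic. As an illustrative sketch, assuming the log-sum-exp violation loss is exposed as `LSELoss` in `fairret.loss` (as in the fairret documentation), this would look like:" ], "metadata": { "collapsed": false }, "id": "lselossdemomd001" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "# Illustrative sketch: reuse CustomCost with another fairret.\n", "# Assumes fairret.loss exposes the log-sum-exp violation loss as LSELoss.\n", "from fairret.loss import LSELoss\n", "\n", "lse_loss = LSELoss(statistic)\n", "print(lse_loss(model(feat), sens, label))" ], "metadata": { "collapsed": false }, "id": "lselossdemocode1" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3" } }, "nbformat": 4, "nbformat_minor": 5 }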